# -*- coding: utf-8 -*- #
"""
Time                2023/3/19 23:41
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                load_balancing_strategy.py
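Description:        Load balancing strategies (random and round-robin) for
                    selecting one service instance from a list of candidates.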
"""
import random
from abc import ABCMeta, abstractmethod
from typing import List, Optional
from enum import Enum
from .exceptions import CmgException
from .service_instance import ServiceInstance


class LoadBalancingStrategy(Enum):
    """Identifiers for the load balancing strategies defined in this module."""
    # Mapped to RandomLoadBalancingStrategy by create_load_balancing_strategy.
    RANDOM = 0
    # Mapped to RobinLoadBalancingStrategy by create_load_balancing_strategy.
    ROBIN = 1


def create_load_balancing_strategy(strategy: LoadBalancingStrategy, *args,
                                   **kwargs):
    """Build the load balancing strategy object matching ``strategy``.

    Args:
        strategy: Which strategy to instantiate.
        *args: Positional arguments forwarded to the strategy constructor.
        **kwargs: Keyword arguments forwarded to the strategy constructor.

    Returns:
        A RandomLoadBalancingStrategy for ``RANDOM`` or a
        RobinLoadBalancingStrategy for ``ROBIN``.

    Raises:
        CmgException: If ``strategy`` is not one of the supported strategies,
            including values that are not LoadBalancingStrategy members.
    """
    if strategy == LoadBalancingStrategy.RANDOM:
        return RandomLoadBalancingStrategy(*args, **kwargs)
    if strategy == LoadBalancingStrategy.ROBIN:
        return RobinLoadBalancingStrategy(*args, **kwargs)

    # Only enum members have ``name``/``value``. Anything else (an int, a
    # str, None, ...) is described with repr() so that the caller still gets
    # the documented CmgException rather than an AttributeError raised while
    # the error message is being formatted.
    if isinstance(strategy, LoadBalancingStrategy):
        described = f"{strategy.name}:{strategy.value}"
    else:
        described = repr(strategy)
    raise CmgException(
        f"Not support load balancing strategy => {described}"
    )


class LoadBalancingStrategyBase(metaclass=ABCMeta):
    """Abstract interface that every load balancing strategy implements."""

    @abstractmethod
    def update(self, instances: List[ServiceInstance]):
        """Replace the candidate instances that ``select`` chooses from.

        Args:
            instances: The new list of candidate service instances.
        """
        pass

    @abstractmethod
    def select(self) -> Optional[ServiceInstance]:
        """Choose one instance from the current candidates.

        Returns:
            The chosen instance, or None when there is no candidate to
            choose from (the concrete strategies in this module return None
            for an empty candidate list).
        """
        pass

    @abstractmethod
    def type(self) -> LoadBalancingStrategy:
        """Return the LoadBalancingStrategy member identifying this strategy."""
        pass


class RandomLoadBalancingStrategy(LoadBalancingStrategyBase):
    """Strategy that returns the instance at a randomly drawn index."""

    def __init__(self):
        # Candidate instances; empty until update() supplies a list.
        self.instances = []

    def update(self, instances: List[ServiceInstance]):
        """Replace the candidate list (the given list is stored as-is)."""
        self.instances = instances

    def select(self) -> Optional[ServiceInstance]:
        """Return a randomly picked candidate, or None if there are none."""
        count = len(self.instances)
        if count == 0:
            return None
        # randint is inclusive on both ends, hence the ``count - 1`` bound.
        picked = random.randint(0, count - 1)
        return self.instances[picked]

    def type(self) -> LoadBalancingStrategy:
        """Identify this strategy as ``LoadBalancingStrategy.RANDOM``."""
        return LoadBalancingStrategy.RANDOM


class RobinLoadBalancingStrategy(LoadBalancingStrategyBase):
    """Round-robin strategy: hands out candidates in list order, cyclically."""

    def __init__(self):
        # Candidate instances; empty until update() supplies a list.
        self.instances = []
        # Index of the instance the next select() call will return.
        self.cur_idx = 0

    def update(self, instances: List[ServiceInstance]):
        """Replace the candidate list, keeping the rotation position if valid.

        The cursor is reset to the start only when it no longer points inside
        the new list. A cursor that is still in range is kept so that a
        refresh of the list does not restart the rotation at the first
        instance.

        Args:
            instances: The new list of candidate service instances (stored
                as-is, not copied).
        """
        self.instances = instances
        # The new list may be shorter than the old one; an out-of-range
        # cursor would make the next select() raise IndexError.
        if self.cur_idx >= len(self.instances):
            self.cur_idx = 0

    def select(self) -> Optional[ServiceInstance]:
        """Return the next candidate in rotation, or None if there are none."""
        count = len(self.instances)
        if count == 0:
            return None
        # The list is held by reference, so it may have shrunk since the last
        # update(); wrapping keeps the cursor inside the current bounds.
        chosen_idx = self.cur_idx % count
        self.cur_idx = (chosen_idx + 1) % count
        return self.instances[chosen_idx]

    def type(self) -> LoadBalancingStrategy:
        """Identify this strategy as ``LoadBalancingStrategy.ROBIN``."""
        return LoadBalancingStrategy.ROBIN
